{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from cot import Collection\n",
    "from cot.generate import FRAGMENTS\n",
    "from rich.pretty import pprint\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cot.create import create_thoughtsource_100, create_thoughtsource_1\n",
    "from sprint_utils import get_answer_extraction, join_strings #model_sprint_thoughtsource"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_sprint_thoughtsource_instruction(\n",
    "    collection,\n",
    "    dataset_name_splits,\n",
    "    api_service_model,\n",
    "    instruction_key=None,\n",
    "    output_prefix=\"thoughtsource_100\",\n",
    "):\n",
    "    \"\"\"Generate with every dataset/model combination, then evaluate and dump the collection.\n",
    "\n",
    "    Args:\n",
    "        collection: Object indexed as collection[dataset_name][split][0] that also\n",
    "            provides generate(), evaluate() and dump().\n",
    "        dataset_name_splits: (dataset_name, split) pairs to process.\n",
    "        api_service_model: (api_service, model, api_time_interval) triples; each\n",
    "            triple is used once per dataset.\n",
    "        instruction_key: Passed through unchanged as the config's \"instruction_keys\"\n",
    "            and handed to join_strings() when the dump file name is built.\n",
    "        output_prefix: First component of the dump file name; the default keeps\n",
    "            the previously hard-coded \"thoughtsource_100\".\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If dataset_name_splits or api_service_model is empty.\n",
    "    \"\"\"\n",
    "    # Materialise both arguments: the model list is re-iterated for every\n",
    "    # dataset and indexed again when the dump file name is built.\n",
    "    dataset_name_splits = list(dataset_name_splits)\n",
    "    api_service_model = list(api_service_model)\n",
    "    if not dataset_name_splits or not api_service_model:\n",
    "        # Without this guard the failure only surfaced as a NameError on the\n",
    "        # dump line, because the loop variables used there were never bound.\n",
    "        raise ValueError(\"dataset_name_splits and api_service_model must both be non-empty\")\n",
    "\n",
    "    for dataset_name, split in dataset_name_splits:\n",
    "        print(dataset_name)\n",
    "        # The answer-extraction key is derived from the first item of the split only.\n",
    "        first_item = collection[dataset_name][split][0]\n",
    "        answer_extraction_key = get_answer_extraction(first_item[\"type\"], first_item[\"choices\"])\n",
    "        for api_service, model, api_time_interval in api_service_model:\n",
    "            config = {\n",
    "                \"instruction_keys\": instruction_key,\n",
    "                \"cot_trigger_keys\": None,\n",
    "                # Original note on this key: Therefore, among A through C/D/E/F, the answer is\n",
    "                \"answer_extraction_keys\": [answer_extraction_key],\n",
    "                \"author\": \"thoughtsource\",\n",
    "                \"api_service\": api_service,\n",
    "                \"engine\": model,\n",
    "                \"temperature\": 0,\n",
    "                \"max_tokens\": 512,\n",
    "                \"api_time_interval\": api_time_interval,\n",
    "                \"verbose\": False,\n",
    "                \"warn\": False,\n",
    "            }\n",
    "            collection.generate(name=dataset_name, split=split, config=config)\n",
    "\n",
    "    collection.evaluate()\n",
    "\n",
    "    # As before, the file name reflects only the last (api_service, model) entry;\n",
    "    # it is now read from the list instead of from leaked loop variables.\n",
    "    last_api_service, last_model, _ = api_service_model[-1]\n",
    "    collection.dump(\n",
    "        output_prefix\n",
    "        + \"_\" + last_api_service\n",
    "        + \"_\" + last_model.replace(\"/\", \"_\")\n",
    "        + join_strings(instruction_key)\n",
    "        + \".json\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (dataset_name, split) pairs, unpacked in that order by model_sprint_thoughtsource_instruction.\n",
    "dataset_name_splits = [\n",
    "    (\"commonsense_qa\", \"validation\"),\n",
    "    (\"med_qa\", \"test\"),\n",
    "    (\"medmc_qa\", \"validation\"),\n",
    "    (\"open_book_qa\", \"test\"),\n",
    "    (\"strategy_qa\", \"train\"),\n",
    "    (\"worldtree\", \"test\"),\n",
    "]\n",
    "\n",
    "# (api_service, model, api_time_interval) triples; uncomment exactly one assignment.\n",
    "# api_service_model = [(\"cohere\", \"command-xlarge-nightly\", 1)]\n",
    "api_service_model = [(\"openai\", \"text-davinci-003\", 1)]\n",
    "# api_service_model = [(\"mock_api\", \"mock_model\", 0)]\n",
    "\n",
    "# Forwarded as the generation config's \"instruction_keys\" entry.\n",
    "instruction_key = [\"qa-04\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the collection that the generation and evaluation cells below operate on.\n",
    "collection = Collection.from_json(\n",
    "    \"./thoughtsource_data/thoughtsource_1_empty.json\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "commonsense_qa\n",
      "Generating commonsense_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8f642a440db24a389cc637adfce2480e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "med_qa\n",
      "Generating med_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "99bfbf09535646229d491bd8ff9f96d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "medmc_qa\n",
      "Generating medmc_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "113435b7fb4d4e34a6b230513adaa5f0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "open_book_qa\n",
      "Generating open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ba20af49c63c4b1fb37c6344eccbc2ac",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "strategy_qa\n",
      "Generating strategy_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aa180b8a209b49b3b24b1037edee2e05",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "worldtree\n",
      "Generating worldtree...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a61d66a73628458e84726942f132e5fd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating commonsense_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5e2afbbbebd74c76a36e5b7cbfdd9972",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating med_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ad3e1705b86c486db0a0022a3d172de0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating medmc_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00d02bcbe620406c9a9b2571c6b13826",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d0f7d922a5043f8901bb1fdbb380d73",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating strategy_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dbf8a03284de4f81887dfd9ea1ad545c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating worldtree...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8c1f5f64a92845e6bea06e825f8efd1e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "29f7bd13ca4c4e46b755148df476e2c5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "48c620d9edce40be97dfcc6bf25998a1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b4755e51f4314dc2a8fdbd8d0086735d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "39ed4fe993c447c5a05c4b970fdf666d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bfbc72a9dd804290a90c64b004a30baf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "83d1aff2d0bb465babe409a0fbde785e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Generate, evaluate and dump results for the datasets and model configured above.\n",
    "model_sprint_thoughtsource_instruction(\n",
    "    collection,\n",
    "    dataset_name_splits,\n",
    "    api_service_model,\n",
    "    instruction_key=instruction_key,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating commonsense_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "884f62d00f2e41519d912010f3feebf5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating med_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "61ccb0e60d7c41fd9e35e60273fe7e2c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating medmc_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5cadedcc7f53477d82e2adfbdf925afc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c2c6e8ad75f4a6c95cf34fd1b54560f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating strategy_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "24210905120143eeb5901ea01a50d685",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating worldtree...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4af0703f8fde4d39a93fd4fb98868138",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'commonsense_qa': {'validation': {'accuracy': {'text-davinci-003': {'qa-04_None_kojima-A-E': 0.0}}}},\n",
       " 'med_qa': {'test': {'accuracy': {'text-davinci-003': {'qa-04_None_kojima-A-E': 1.0}}}},\n",
       " 'medmc_qa': {'validation': {'accuracy': {'text-davinci-003': {'qa-04_None_kojima-A-D': 0.0}}}},\n",
       " 'open_book_qa': {'test': {'accuracy': {'text-davinci-003': {'qa-04_None_kojima-A-D': 1.0}}}},\n",
       " 'strategy_qa': {'train': {'accuracy': {'text-davinci-003': {'qa-04_None_kojima-yes-no': 0.0}}}},\n",
       " 'worldtree': {'test': {'accuracy': {'text-davinci-003': {'qa-04_None_kojima-A-D': 0.0}}}}}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Named eval_results rather than eval so the built-in eval() is not shadowed in the kernel.\n",
    "eval_results = collection.evaluate()\n",
    "eval_results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "ccbfa654f25866afe66a1c016d0b518a994fbe20a93c2d8b432dbe114020e2d2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
